from math import nan
import math
from re import I
import sys, os
import cv2
from PIL import Image  
import pickle
import numpy as np
import glob
from numpy.core.fromnumeric import size, transpose
from numpy.lib.function_base import flip
import matplotlib.pyplot as plt
from numpy.lib.type_check import real
import seaborn as sns
from tkinter import Tcl
from ToF_Validator import set_box_color; sns.set_theme()
from time import process_time


def refine_image(input_tof_image, invalid_mask, index, input_image,
                 save_dir=None, name_prefix=None):
    """Save a side-by-side figure of one ToF frame and its camera image.

    The ToF values are divided by 10 and rounded. The heatmap is labelled
    "(cm)", so the raw values are presumably millimetres (TODO confirm).
    Zones flagged in *invalid_mask* are replaced by NaN. The flat frame is
    reshaped into a square grid (8x8 for a 64-zone frame) and drawn as an
    annotated heatmap next to the camera image.

    Args:
        input_tof_image: flat sequence of per-zone distances for one frame.
            Its length must be a perfect square (e.g. 64).
        invalid_mask: sequence with the same length. Truthy entries mark
            zones that are blanked out.
        index: frame index, used in the output file name.
        input_image: image passed to ``plt.imshow`` (e.g. the array
            returned by ``cv2.imread``).
        save_dir: output path prefix. Defaults to the module-level
            ``address2save`` global.
        name_prefix: file-name prefix. Defaults to the module-level
            ``folder_name`` global.

    The figure is written to ``save_dir + name_prefix + "_N<index>.png"``.

    Raises:
        ValueError: if the frame is empty or its length is not a perfect
            square.
    """
    # asarray(..., float) + division always yields a fresh array, so the
    # caller's data is never modified by the NaN assignment below.
    img_arr = np.asarray(input_tof_image, dtype=float) / 10.0
    img_arr[np.asarray(invalid_mask, dtype=bool)] = nan

    zones = img_arr.size
    side = int(round(math.sqrt(zones)))
    if zones == 0 or side * side != zones:
        raise ValueError("ToF frame has %d zones; expected a non-empty square grid" % zones)
    img_arr = np.around(img_arr).reshape(side, side)

    if save_dir is None:
        save_dir = address2save
    if name_prefix is None:
        name_prefix = folder_name

    fig = plt.figure(figsize=(20, 8))
    try:
        # Right panel: camera frame without grid or axes.
        fig.add_subplot(122)
        plt.title('Camera')
        plt.imshow(input_image)
        # Positional False: the `b=` keyword was removed from matplotlib's grid().
        plt.grid(False)
        plt.axis('off')

        # Left panel: ToF heatmap annotated with the per-zone values.
        fig.add_subplot(121)
        sns.heatmap(img_arr, annot=True, fmt=".0f",
                    annot_kws={"size": 18}, vmin=0, vmax=10000)
        # NOTE(review): this changes global seaborn/matplotlib style after the
        # heatmap is drawn, so it likely only affects elements created later
        # (e.g. the title) and subsequent figures. Confirm the intended ordering.
        sns.set(font_scale=1.9)
        # Unit label placed just outside the bottom-right corner of the grid.
        plt.text(side + 0.25, side + 0.4, "(cm)", fontsize=15, color='orange')
        plt.title('ToF')

        fig.savefig(save_dir + name_prefix + "_N" + str(index) + ".png",
                    format="png", dpi=200)
    finally:
        # Always free the figure, even if drawing or saving fails.
        plt.close(fig)
    

def _load_tof_log(path):
    """Read the pickled ToF log at *path* into four per-frame lists.

    The file is a flat stream of pickled objects written four per frame, in
    the order: timestamp, distances, targets, status. A trailing incomplete
    frame (e.g. a file truncated mid-write) is dropped, so all four returned
    lists have the same length.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    ts, dis, tar, sts = [], [], [], []
    columns = (ts, dis, tar, sts)
    # NOTE(security): pickle.load can execute arbitrary code. Only use this on
    # trusted, locally recorded log files.
    with open(path, 'rb') as f:
        frame = []
        while True:
            try:
                frame.append(pickle.load(f))
            except EOFError:
                break
            if len(frame) == len(columns):
                for column, value in zip(columns, frame):
                    column.append(value)
                frame = []
    return ts, dis, tar, sts


def _build_invalid_mask(status, targets):
    """Return a float mask (1.0 = invalid, 0.0 = valid) per frame and zone.

    A zone is valid only when its status is 5 or 9 AND its targets value is 1.
    """
    status = np.asarray(status)
    targets = np.asarray(targets)
    invalid = ((status != 5) & (status != 9)) | (targets != 1)
    return invalid.astype(float)


def ProcessImages(address, address2save):
    """Pair each camera image with its ToF frame and render the comparison.

    Reads ``<address>/state_ToF.dat``, builds the per-zone validity mask and,
    for every file in ``<address>/images`` whose stem is a timestamp present
    in the log, calls ``refine_image`` with the matching ToF frame.

    Args:
        address: recording directory (trailing slash optional).
        address2save: kept for interface compatibility. ``refine_image``
            chooses its output location itself (by default from the
            module-level ``address2save`` global), not from this argument.

    Returns:
        None. Prints a message and returns early when the log is missing or
        contains no complete frame.
    """
    log_path = os.path.join(address, "state_ToF.dat")
    try:
        ts, dis, tar, sts = _load_tof_log(log_path)
    except FileNotFoundError:
        print("file not found!")
        return
    if not ts:
        print("no ToF frames found in " + log_path)
        return

    message = str(len(ts)) + " images has been read succesfully."
    duration = ts[-1] - ts[0]
    # Skip the frame-rate figure when it would divide by zero (e.g. a single frame).
    if duration:
        message += " FR:" + str(round((1000 * len(ts)) / duration, 2))
    print(message)

    distances = np.array(dis)
    invalid_mask = _build_invalid_mask(np.array(sts), np.array(tar))

    # Map each timestamp to its first frame index once. This matches ts.index()
    # while avoiding an O(n) search per image.
    frame_of = {}
    for idx, stamp in enumerate(ts):
        frame_of.setdefault(stamp, idx)

    images_dir = os.path.join(address, 'images')
    for fname in os.listdir(images_dir):
        image_path = os.path.join(images_dir, fname)
        if not os.path.isfile(image_path):
            continue
        stem = os.path.splitext(fname)[0]
        # Keep the try narrow: only a non-numeric name or an unknown
        # timestamp counts as "not found".
        try:
            frame = frame_of[int(stem)]
        except (ValueError, KeyError):
            print(fname + " not Found in Tof Images")
            continue
        refine_image(distances[frame, :], invalid_mask[frame, :], frame,
                     cv2.imread(image_path))

def Make_video(address, name=None, fps=5):
    """Assemble every ``*.png`` in *address* into an MJPG ``.avi`` video.

    Frames are ordered with a natural sort, so e.g. ``x_N2.png`` comes before
    ``x_N10.png``. The frame size is taken from the first readable image.
    Unreadable files are reported and skipped.

    Args:
        address: directory holding the PNG frames. The video is written here too.
        name: base name of the video file, without extension. Defaults to the
            last path component of *address*.
        fps: frames per second of the output video.

    Returns:
        None. Prints a message and returns early if no PNG frames exist.
    """
    import re

    def natural_key(path):
        # re.split with a capture group alternates text/digits, so even
        # positions are text (compared case-insensitively) and odd positions
        # are digit runs (compared numerically). Keys therefore always compare
        # str with str and int with int.
        return [int(tok) if pos % 2 else tok.lower()
                for pos, tok in enumerate(re.split(r'(\d+)', path))]

    file_list = sorted(glob.glob(os.path.join(address, '*.png')), key=natural_key)
    if not file_list:
        print("no PNG frames found in " + address)
        return

    if name is None:
        name = os.path.basename(os.path.normpath(address))
    video_path = os.path.join(address, name + '.avi')

    out = None
    try:
        for filename in file_list:
            print(filename)
            img = cv2.imread(filename)
            if img is None:
                print("could not read " + filename + ", skipped")
                continue
            if out is None:
                # Open the writer lazily, sized to the first readable frame.
                height, width = img.shape[:2]
                out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'MJPG'),
                                      fps, (width, height))
            out.write(img)
    finally:
        if out is not None:
            out.release()


if __name__ == "__main__":
    # Recordings to process. Swap in a shorter list to run a single one, e.g.
    # folder_names = ["Hand"]
    folder_names = ["Column_1Direction",
                    "Hand",
                    "column_back&force"]

    # Start the stopwatch once, so the report below covers every folder.
    # process_time() measures this process's CPU time, not wall-clock time.
    t1_start = process_time()
    for i, folder_name in enumerate(folder_names):
        print(str(i) + "_" + folder_name)
        # folder_name and address2save are module-level globals here. Other
        # functions in this module may read them directly, so keep these names.
        address = "../../../Data Log/object movement/" + folder_name + "/"
        address2save = "../../../Data Log/LogResults/Syncronization/" + folder_name + "/"
        # ProcessImages(address, address2save)
        Make_video(address2save)
    # Stop the stopwatch / counter
    t1_stop = process_time()

    print("Elapsed time during the whole program in seconds:",
          t1_stop - t1_start)


   
